from collections import deque
from typing import Hashable, Iterable, Optional

import networkx as nx
import numpy as np

from pgmpy.base._mixin_roles import _GraphRolesMixin


class AncestralBase(nx.Graph, _GraphRolesMixin):
    """
    Base class for mixed graphs whose edges carry a mark at each endpoint.

    Every edge stores a ``marks`` attribute mapping both of its endpoints to
    one of ``">"``, ``"-"`` or ``"o"``. Because the underlying container is an
    undirected ``networkx.Graph``, the orientation of an edge is encoded only
    by these marks, never by the order in which the endpoints were given.
    """

    def __init__(
        self,
        ebunch: Optional[Iterable[tuple[Hashable, Hashable, str, str]]] = None,
        latents: Iterable[Hashable] = frozenset(),
        exposures: Iterable[Hashable] = frozenset(),
        outcomes: Iterable[Hashable] = frozenset(),
        roles=None,
    ):
        """
        Ancestral graph base class.
        Internally, each edge is stored with an attribute dictionary
        called ``marks``. The ``marks`` dict maps the two endpoint
        nodes to their respective marks, for example:

        - Directed: ("A", "B", "-", ">") is stored as
          ("A", "B", {"marks": {"A": "-", "B": ">"}})
        - Bidirected: ("A", "B", ">", ">") is stored as
          ("A", "B", {"marks": {"A": ">", "B": ">"}})
        - Undirected: ("A", "B", "-", "-") is stored as
          ("A", "B", {"marks": {"A": "-", "B": "-"}})
        - Circle endpoint: ("A", "B", "o", ">") is stored as
          ("A", "B", {"marks": {"A": "o", "B": ">"}})

        Parameters
        ----------
        ebunch : Iterable[tuple], optional
            An iterable of edges of the form (u, v, u_mark, v_mark) used to
            initialize the graph. Each mark must be one of {">", "-", "o"}.
            Default is None, which initializes an empty graph.

        latents : iterable, optional
            Latent (unobserved) variables in the graph. Stored as a set.
            Default is empty.

        exposures : iterable, optional
            Exposure variables in the graph. These are the variables
            that represent the treatment or intervention being studied in a
            causal analysis. Stored as a set. Default is empty.

        outcomes : iterable, optional
            Outcome variables in the graph. These are the variables
            that represent the response or dependent variables being studied
            in a causal analysis. Stored as a set. Default is empty.

        roles : dict, optional (default: None)
            A dictionary mapping roles to nodes.
            The keys are role names, and the values are the node name(s)
            holding that role (a single name or an iterable of names).
            Passing a key-value pair via ``roles`` is equivalent to calling
            ``with_role(role, variables)`` for each key-value pair in the dictionary.

        Raises
        ------
        TypeError
            If ``roles`` is given and is not a dict.

        ValueError
            If an edge in ``ebunch`` is a self-loop or uses an invalid mark.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> list(graph.edges(data=True))
        [('A', 'B', {'marks': {'A': '-', 'B': '>'}}),
         ('B', 'C', {'marks': {'B': '>', 'C': '-'}})]
        >>> graph.add_edge("C", "D", "o", "o")
        >>> list(graph.edges(data=True))
        [('A', 'B', {'marks': {'A': '-', 'B': '>'}}),
         ('B', 'C', {'marks': {'B': '>', 'C': '-'}}),
         ('C', 'D', {'marks': {'C': 'o', 'D': 'o'}})]

        Roles can be assigned to nodes in the graph at construction or using methods.

        At construction:

        >>> g = AncestralBase(
        ...     ebunch=[("L", "A", "-", ">"), ("B", "C", "-", ">")],
        ...     latents={"L"},
        ...     exposures={"A"},
        ...     outcomes={"B"},
        ... )

        Roles can also be assigned after creation using ``with_role`` method.

        >>> g = g.with_role("adjustment", {"L", "C"})

        Vertices of a specific role can be retrieved using ``get_role`` method.

        >>> g.get_role("exposure")
        ["A"]
        >>> g.get_role("adjustment")
        ["L", "C"]
        """
        super().__init__()
        self.valid_marks = {">", "-", "o"}
        if ebunch is not None:
            self.add_edges_from(ebunch)
        # Always copy into fresh sets so the caller's containers (and the
        # immutable defaults) are never shared with this instance.
        self.latents = set(latents)
        self.exposures = set(exposures)
        self.outcomes = set(outcomes)

        if roles is None:
            roles = {}
        elif not isinstance(roles, dict):
            raise TypeError("Roles must be provided as dictionary")

        for role, variables in roles.items():
            self.with_role(role=role, variables=variables, inplace=True)

    @property
    def adjacency_matrix(self):
        """
        Return adjacency matrix with edge marks and node-to-index mapping.

        Returns
        -------
        M : np.ndarray
            An object-dtype square matrix of shape (n_nodes, n_nodes) where
            M[i, j] is the mark at node j for edge (i, j), and 0 where nodes
            i and j are not adjacent.

        node_index : dict
            Mapping from node label to row/col index (graph iteration order).

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> M, node_index = graph.adjacency_matrix
        >>> print(M)
        [[0 '>' 0]
         ['-' 0 '-']
         [0 '>' 0]]
        >>> print(node_index)
        {'A': 0, 'B': 1, 'C': 2}
        """
        node_index = {node: i for i, node in enumerate(self.nodes)}
        n = len(node_index)

        M = np.full((n, n), 0, dtype=object)

        for u, v, data in self.edges(data=True):
            marks = data["marks"]
            u_idx, v_idx = node_index[u], node_index[v]
            # Row index = "from" endpoint, column index = endpoint whose mark is stored.
            M[u_idx, v_idx] = marks[v]
            M[v_idx, u_idx] = marks[u]

        return M, node_index

    @adjacency_matrix.setter
    def adjacency_matrix(self, value):
        """
        Replace the graph's contents with the edges encoded in a mark matrix.

        The convention is the same as the getter's, so assigning the matrix
        returned by the getter reproduces the same edges and marks (nodes are
        renamed to ``X_0 ... X_{n-1}``).

        Parameters
        ----------
        value : array-like
            A square matrix where value[i, j] is the mark at node j
            for edge (i, j). Marks must be one of {">", "-", "o"},
            or 0 (no edge). The string "0" is also accepted as "no edge",
            because NumPy coerces a 0 to "0" in string-typed arrays.
            An edge is created only if both value[i, j] and value[j, i]
            hold a mark.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``value`` is not a square 2-D matrix, or contains an invalid mark.

        Examples
        --------
        >>> import numpy as np
        >>> from pgmpy.base import AncestralBase
        >>> M = np.array([[0, ">", 0], ["-", 0, ">"], [0, "-", 0]], dtype=object)
        >>> graph = AncestralBase()
        >>> graph.adjacency_matrix = M
        >>> print(graph.nodes)
        ['X_0', 'X_1', 'X_2']
        >>> print(graph.edges(data=True))
        [('X_0', 'X_1', {'marks': {'X_0': '-', 'X_1': '>'}}), ('X_1', 'X_2', {'marks': {'X_1': '-', 'X_2': '>'}})]
        """
        # dtype=object keeps integer 0 entries of mixed lists as 0 instead of
        # letting NumPy coerce the whole matrix to strings.
        value = np.asarray(value, dtype=object)
        if value.ndim != 2 or value.shape[0] != value.shape[1]:
            raise ValueError("Adjacency matrix must be square (n x n).")

        def has_mark(entry):
            return not (entry == 0 or entry == "0")

        n = value.shape[0]
        variables = [f"X_{i}" for i in range(n)]
        self.clear()
        # Add every node so rows/columns without edges are kept.
        self.add_nodes_from(variables)
        for i in range(n):
            # Each unordered pair is visited once; (i, j) and (j, i) describe the same edge.
            for j in range(i + 1, n):
                mark_at_j = value[i, j]
                mark_at_i = value[j, i]
                if has_mark(mark_at_i) and has_mark(mark_at_j):
                    self.add_edge(variables[i], variables[j], mark_at_i, mark_at_j)

    def add_edge(self, u, v, u_mark, v_mark):
        """
        Add an edge with specified marks.

        If the edge already exists, its marks are replaced.

        Parameters
        ----------
        u : Hashable
            One endpoint of the edge.

        v : Hashable
            The other endpoint of the edge.

        u_mark : str
            Mark at node u for edge (u, v). Must be one of {">", "-", "o"}.

        v_mark : str
            Mark at node v for edge (u, v). Must be one of {">", "-", "o"}.

        Returns
        -------
        None
            Adds the edge to the graph in-place

        Raises
        ------
        ValueError
            If ``u == v`` or either mark is not one of {">", "-", "o"}.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> g = AncestralBase()

        # Directed edge A → B
        >>> g.add_edge("A", "B", "-", ">")
        >>> g["A"]["B"]["marks"]
        {'A': '-', 'B': '>'}

        # Bidirected edge A ↔ D
        >>> g.add_edge("A", "D", ">", ">")
        >>> g["A"]["D"]["marks"]
        {'A': '>', 'D': '>'}

        # Undirected edge C — E
        >>> g.add_edge("C", "E", "-", "-")
        >>> g["C"]["E"]["marks"]
        {'C': '-', 'E': '-'}
        """
        if u == v:
            raise ValueError("Nodes cannot be the same for an edge.")
        if u_mark not in self.valid_marks or v_mark not in self.valid_marks:
            raise ValueError(f"Marks must be one of {self.valid_marks}.")
        super().add_edge(u, v, marks={u: u_mark, v: v_mark})

    def add_edges_from(self, ebunch):
        """
        Add multiple edges from an iterable of (u, v, u_mark, v_mark) tuples.

        Parameters
        ----------
        ebunch : Iterable[tuple]
            Each tuple should be of the form (u, v, u_mark, v_mark).

        Returns
        -------
        None
            Adds the edges to the graph in-place.

        Raises
        ------
        ValueError
            If a tuple has the wrong length, or ``add_edge`` rejects an edge.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> g = AncestralBase()
        >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-"), ("C", "D", "o", "o")]
        >>> g.add_edges_from(edges)
        >>> list(g.edges(data=True))
        [('A', 'B', {'marks': {'A': '-', 'B': '>'}}),
         ('B', 'C', {'marks': {'B': '>', 'C': '-'}}),
         ('C', 'D', {'marks': {'C': 'o', 'D': 'o'}})]
        """
        for u, v, u_mark, v_mark in ebunch:
            self.add_edge(u, v, u_mark, v_mark)

    def get_neighbors(self, node, u_type=None, v_type=None):
        """
        Get neighbors of a node with optional edge mark constraints.

        Parameters
        ----------
        node : Hashable
            The node whose neighbors are to be found.

        u_type : Optional[str]
            Required mark at the given node for the edge. None means any mark.

        v_type : Optional[str]
            Required mark at the neighbor node for the edge. None means any mark.

        Returns
        -------
        neighbors : set
            Set of neighboring nodes satisfying the mark constraints. Empty if
            ``node`` is not in the graph.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("B", "C", ">", "-"), ("C", "D", "o", "o")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_neighbors("B"))
        {'A', 'C'}
        >>> print(graph.get_neighbors("B", u_type=">"))
        {'A', 'C'}
        >>> print(graph.get_neighbors("B", v_type="-"))
        {'A', 'C'}
        >>> print(graph.get_neighbors("B", u_type=">", v_type="-"))
        {'A', 'C'}
        """
        if node not in self:
            return set()
        return {
            neighbor
            for neighbor, data in self.adj[node].items()
            if (u_type is None or data["marks"][node] == u_type)
            and (v_type is None or data["marks"][neighbor] == v_type)
        }

    def get_parents(self, node):
        """
        Get nodes with a directed edge into ``node`` (parent -> node).

        Parameters
        ----------
        node : Hashable
            The node whose parents are to be found.

        Returns
        -------
        parents : set
            Set of parent nodes.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("C", "B", "-", ">"), ("B", "D", "-", ">")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_parents("B"))
        {'A', 'C'}
        >>> print(graph.get_parents("D"))
        {'B'}
        >>> print(graph.get_parents("A"))
        set()
        """
        return self.get_neighbors(node, u_type=">", v_type="-")

    def get_children(self, node):
        """
        Get nodes that ``node`` has a directed edge into (node -> child).

        Parameters
        ----------
        node : Hashable
            The node whose children are to be found.

        Returns
        -------
        children : set
            Set of child nodes.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", "-", ">"), ("A", "C", "-", ">"), ("B", "D", "-", ">")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_children("A"))
        {'B', 'C'}
        >>> print(graph.get_children("B"))
        {'D'}
        >>> print(graph.get_children("D"))
        set()
        """
        return self.get_neighbors(node, u_type="-", v_type=">")

    def get_spouses(self, node):
        """
        Get nodes connected to ``node`` by a bidirected edge (``>`` at both ends).

        Parameters
        ----------
        node : Hashable
            The node whose spouses are to be found.

        Returns
        -------
        spouses : set
            Set of spouse nodes.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [("A", "B", ">", ">"), ("A", "C", "-", ">"), ("C", "D", ">", ">")]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_spouses("A"))
        {'B'}
        >>> print(graph.get_spouses("C"))
        {'D'}
        >>> print(graph.get_spouses("B"))
        {'A'}
        """
        return self.get_neighbors(node, u_type=">", v_type=">")

    def _collect_reachable(self, start, expand):
        """Return ``start`` plus every node reached by repeatedly applying ``expand`` (BFS)."""
        # The start node is wrapped in a list: deque(start) would iterate the
        # label itself (splitting strings, failing on ints).
        visited = {start}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for nxt in expand(current):
                if nxt not in visited:
                    visited.add(nxt)
                    queue.append(nxt)
        return visited

    def get_ancestors(self, node):
        """
        Get all ancestor nodes of the given node.

        Parameters
        ----------
        node : Hashable
            The node whose ancestors are to be found.

        Returns
        -------
        ancestors : set
            Set of ancestor nodes, including the starting node (which is
            returned even when it is not in the graph).

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [
        ...     ("A", "B", "-", ">"),
        ...     ("B", "C", "-", ">"),
        ...     ("C", "D", "-", ">"),
        ...     ("E", "C", "-", ">"),
        ... ]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_ancestors("D"))
        {'A', 'B', 'C', 'D', 'E'}
        >>> print(graph.get_ancestors("C"))
        {'A', 'B', 'C', 'E'}
        >>> print(graph.get_ancestors("A"))
        {'A'}
        """
        return self._collect_reachable(node, self.get_parents)

    def get_descendants(self, node):
        """
        Get all descendant nodes (children, grandchildren, etc.)

        Parameters
        ----------
        node : Hashable
            The starting node.

        Returns
        -------
        descendants : set
            Set of descendant nodes, including the starting node (which is
            returned even when it is not in the graph).

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [
        ...     ("A", "B", "-", ">"),
        ...     ("B", "C", "-", ">"),
        ...     ("C", "D", "-", ">"),
        ...     ("B", "E", "-", ">"),
        ... ]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_descendants("A"))
        {'A', 'B', 'C', 'D', 'E'}
        >>> print(graph.get_descendants("B"))
        {'B', 'C', 'D', 'E'}
        >>> print(graph.get_descendants("D"))
        {'D'}
        """
        return self._collect_reachable(node, self.get_children)

    def get_reachable_nodes(self, node, u_type=None, v_type=None):
        """
        Get all nodes reachable from the given node following edges
        with specified marks.

        Parameters
        ----------
        node : Hashable
            The starting node.

        u_type : Optional[str]
            Required mark at the current node for traversal. None means any mark.

        v_type : Optional[str]
            Required mark at the neighbor node for traversal. None means any mark.

        Returns
        -------
        reachable : set
            Set of reachable nodes, including the starting node.

        Examples
        --------
        >>> from pgmpy.base import AncestralBase
        >>> edges = [
        ...     ("A", "B", "-", ">"),
        ...     ("B", "C", "-", ">"),
        ...     ("A", "D", "o", "o"),
        ...     ("D", "E", "o", "o"),
        ... ]
        >>> graph = AncestralBase(ebunch=edges)
        >>> print(graph.get_reachable_nodes("A", v_type=">"))
        {'A', 'B', 'C'}
        >>> print(graph.get_reachable_nodes("A", u_type="o", v_type="o"))
        {'A', 'D', 'E'}
        """
        return self._collect_reachable(
            node,
            lambda current: self.get_neighbors(current, u_type=u_type, v_type=v_type),
        )

    def __eq__(self, other):
        """
        Check whether two ancestral graphs are equal.

        Two graphs are equal if they have the same nodes, the same edges with
        the same marks, the same latent variables, and the same role
        assignments. Edges are compared by their marks only, so the order in
        which an edge's endpoints were given does not matter.

        Parameters
        ----------
        other: AncestralBase
            The other graph to compare with.

        Returns
        -------
        bool
            True if the graphs are equal. False otherwise, including when
            ``other`` is not an ``AncestralBase``.

        Examples
        --------
        >>> from pgmpy.base import MAG
        >>> mag1 = MAG(
        ...     ebunch=[("X", "Y", "-", ">"), ("Y", "Z", "-", ">")],
        ...     latents={"L"},
        ...     roles={"exposure": "X"},
        ... )
        >>> mag2 = MAG(
        ...     ebunch=[("X", "Y", "-", ">"), ("Y", "Z", "-", ">")],
        ...     latents={"L"},
        ...     roles={"exposure": "X"},
        ... )
        >>> mag1 == mag2
        True

        >>> mag3 = MAG(
        ...     ebunch=[("X", "Y", "-", ">")], latents={"L"}, roles={"exposure": "X"}
        ... )
        >>> mag1 == mag3
        False

        >>> from pgmpy.base import AncestralBase
        >>> AncestralBase([("A", "B", "-", ">")]) == AncestralBase([("B", "A", ">", "-")])
        True
        """
        if not isinstance(other, AncestralBase):
            return False

        def edge_signature(graph):
            # The marks dict names both endpoints with their marks, so it
            # identifies an edge without depending on (u, v) iteration order.
            return {
                frozenset(data["marks"].items())
                for _, _, data in graph.edges(data=True)
            }

        return (
            set(self.nodes()) == set(other.nodes())
            and edge_signature(self) == edge_signature(other)
            and self.latents == other.latents
            and self.get_role_dict() == other.get_role_dict()
        )

    def copy(self):
        """
        Return a copy of the graph, preserving nodes, edges, marks, latents, and roles.

        Nodes without any incident edge are kept as well.

        Returns
        -------
        AncestralBase
            A new instance of the same class as self with all properties copied.
        """
        ebunch = [
            (u, v, data["marks"][u], data["marks"][v])
            for u, v, data in self.edges(data=True)
        ]
        ancestral_base = self.__class__(
            ebunch=ebunch,
            latents=self.latents.copy(),
            exposures=self.exposures.copy(),
            outcomes=self.outcomes.copy(),
        )
        # The edge list cannot carry isolated nodes; add them explicitly.
        ancestral_base.add_nodes_from(self.nodes)

        for role, variables in self.get_role_dict().items():
            ancestral_base.with_role(role=role, variables=variables, inplace=True)

        return ancestral_base
